import json
import math
from functools import wraps
from copy import copy
from typing import NamedTuple
from enum import Enum

import numpy as np
import h5py
import torch

from ..utils import SingularGramError


# Registry mapping a serialized kernel name to its kernel class; filled by `register_kernel`.
KERNELS = {}

# Registry mapping an inducer name to its inducer; filled by `register_inducer`. The name 'none' maps to `None`.
INDUCERS = {'none': None}


def register_kernel(name):
    '''Create a class decorator that registers a kernel under `name`.

    Parameters
    ----------
    name : str
        Key under which the kernel is stored in `KERNELS`. It is also written to the kernel's
        `__serialized_name__` attribute.

    Returns
    -------
    callable
        Decorator which registers the kernel it receives and returns it unchanged.
    '''
    def decorator(kernel):
        # remember the name on the kernel itself so that it can be written out when serializing
        kernel.__serialized_name__ = name
        KERNELS[name] = kernel
        return kernel

    return decorator


def register_inducer(name):
    '''Create a decorator that registers an inducer under `name`.

    Parameters
    ----------
    name : str
        Key under which the inducer is stored in `INDUCERS`.

    Returns
    -------
    callable
        Decorator which registers the object it receives and returns it unchanged.
    '''
    def decorator(func):
        INDUCERS[name] = func
        return func

    return decorator


def atomize(obj):
    '''Convert a tensor to a plain Python scalar; return any other object unchanged.

    Parameters
    ----------
    obj : object
        Any object. Tensors are detached and converted using `.item()`.

    Returns
    -------
    object
        The Python scalar held by `obj` if it is a :py:obj:`torch.Tensor`, otherwise `obj` itself.
    '''
    return obj.detach().item() if isinstance(obj, torch.Tensor) else obj


def cdiff(x1, x2):
    '''Compute all pairwise differences between the rows of `x1` and the rows of `x2`.

    A new axis is inserted before the last axis of `x1` and before the second-to-last axis of `x2`, such that the
    subtraction broadcasts over all pairs of rows.

    Parameters
    ----------
    x1 : :py:obj:`torch.Tensor`
        First set of points, with at least two axes (..., samples, dimensions).
    x2 : :py:obj:`torch.Tensor`
        Second set of points, with at least two axes (..., samples, dimensions).

    Returns
    -------
    :py:obj:`torch.Tensor`
        Differences with shape (..., samples of x1, samples of x2, dimensions).
    '''
    rows = x1[..., :, None, :]
    cols = x2[..., None, :, :]
    return rows - cols


def tdiff(x1, x2, p=math.tau, lscale=1):
    '''Compute the pairwise torus distance between two sets of points.

    Parameters
    ----------
    x1 : :py:obj:`torch.Tensor`
        First set of points with shape (number of samples x dimensions).
    x2 : :py:obj:`torch.Tensor`
        Second set of points with shape (number of samples x dimensions).
    p : float
        Periodicity of the manifold (torus).
    lscale : float
        Lengthscale parameter.

    Returns
    -------
    :py:obj:`torch.Tensor`
        Torus distance for all pairs of points in the two sets.
    '''
    scale = p / math.pi / lscale
    # per-dimension scaled sine of the pairwise differences
    per_dim = scale * torch.sin(math.pi / p * cdiff(x1, x2))
    squared = per_dim ** 2
    return squared.sum(dim=-1) ** .5


def slog(x):
    '''Return the sign of `x` together with the logarithm of its absolute value.

    Parameters
    ----------
    x : :py:obj:`torch.Tensor`
        Input values.

    Returns
    -------
    tuple of :py:obj:`torch.Tensor`
        The element-wise sign of `x` and the element-wise logarithm of the absolute value of `x`.
    '''
    magnitude = x.abs()
    return x.sign(), magnitude.log()


class MultivariateNormal:
    '''Multivariate normal distribution, possibly with independent variables.

    Parameters
    ----------
    mean : obj:`torch.Tensor`
        One-dimensional Tensor with means. May have arbitrary batch dimension matching `covar`.
    covar : obj:`torch.Tensor`
        Two-dimensional Tensor of covariance matrix, or one-dimensional Tensor with the diagonal elements of the
        covariance matrix. If two-dimensional, must be square (indicates covariance) or have one dimension equal to 1
        (indicates diagonal). May have an arbitrary batch dimension, in which case two dimensions must be used to
        describe the covariance diagonal or full covariance.
    cholesky : obj:`torch.Tensor`, optional
        Two-dimensional Tensor of the cholesky decomposition of the covariance matrix, or one-dimensional Tensor with
        the diagonal elements of the cholesky matrix. Must have the same shape as `covar`. If not supplied, it is
        computed on first access of the `cholesky` property.

    Raises
    ------
    TypeError
        If the shape of `covar` is invalid, or if it does not match the shape of `mean` or `cholesky`.
    '''
    def __init__(self, mean, covar, cholesky=None):
        if covar.ndim > 1:
            n_rows, n_cols = covar.shape[-2:]
            # the trailing two axes must either be square (full covariance) or contain a 1 (diagonal)
            if n_rows != n_cols and 1 not in (n_rows, n_cols):
                raise TypeError('Shape of covariance is invalid!')

        # pairs of batch sizes of covariance and mean
        batch_pairs = list(zip(covar.shape[:-2], mean.shape[:-1]))
        shapes_agree = (
            (covar.ndim in (1, 2) and mean.ndim == 1)
            or all(1 in {a, b} or a == b for a, b in batch_pairs)
        )
        if not shapes_agree or max(covar.shape[-2:]) != mean.shape[-1]:
            raise TypeError('Shapes of covariance and mean do not match!')

        if cholesky is not None and cholesky.shape != covar.shape:
            raise TypeError('Shapes of covariance and cholesky do not match!')

        if any(a != b for a, b in batch_pairs):
            # expand mean and covariance to their common batch shape
            batch_shape = [max(a, b) for a, b in batch_pairs]
            mean = mean.expand((*batch_shape, mean.shape[-1]))
            covar = covar.expand((*batch_shape, *covar.shape[-2:]))

        self._mean = mean
        self._covar = covar
        self._covrank = None
        self._cholesky = cholesky

    def _has_flat_diag(self):
        '''Return whether the covariance is stored as an unbatched diagonal (1d, or 2d with an axis of size 1).'''
        stored = self._covar
        return stored.ndim == 1 or (stored.ndim == 2 and 1 in stored.shape)

    def __getitem__(self, key):
        '''Return a new distribution from the mean and covariance indexed with `key`.'''
        return MultivariateNormal(self._mean[key], self._covar[key])

    @property
    def mean(self):
        '''Return the means.'''
        return self._mean

    @property
    def covar(self):
        '''Return the covariance matrix.'''
        stored = self._covar
        if stored.ndim == 1:
            return torch.diagflat(stored)
        if 1 in stored.shape[-2:]:
            # diagonal stored with a singleton axis: flatten the last two axes and embed as diagonal matrix
            return torch.diag_embed(stored.flatten(start_dim=-2))
        return stored

    @property
    def var(self):
        '''Return the diagonal of the covariance matrix.'''
        stored = self._covar
        if stored.ndim == 1:
            return stored
        if 1 in stored.shape[-2:]:
            return stored.flatten(start_dim=-2)
        return stored.diagonal(dim1=-2, dim2=-1)

    @property
    def std(self):
        '''Return the square root of the diagonal of the covariance matrix.'''
        return self.var.pow(.5)

    @property
    def cholesky(self):
        '''Return the cholesky factor of the covariance, computing and caching it on first access.'''
        if self._cholesky is not None:
            return self._cholesky
        if self._has_flat_diag():
            # the cholesky factor of a diagonal matrix is the element-wise square root
            self._cholesky = self._covar ** .5
        else:
            self._cholesky = torch.linalg.cholesky(self.covar)
        return self._cholesky

    def sample(self, size=1, randn=torch.randn):
        '''Return samples from the distribution.

        Parameters
        ----------
        size : int or tuple of int
            Leading shape of the returned samples.
        randn : callable
            Function which receives a shape and returns standard normal noise of that shape.

        Returns
        -------
        obj:`torch.Tensor`
            Samples with shape (*size, *shape of mean).
        '''
        lead = (size,) if isinstance(size, int) else size
        noise = randn((*lead, *self._mean.shape))
        if self._has_flat_diag():
            return noise * self.std + self._mean
        return (self.cholesky @ noise[..., None]).squeeze(-1) + self._mean

    @property
    def stats(self):
        '''Return the means and the covariance matrix.'''
        return self.mean, self.covar

    @property
    def diag_stats(self):
        '''Return the means and the diagonal of the covariance matrix.'''
        return self.mean, self.var

    @property
    def covrank(self):
        '''Return the rank of the covariance matrix, computed once and cached.'''
        if self._covrank is None:
            self._covrank = torch.linalg.matrix_rank(self.covar, atol=1e-10)
        return self._covrank

    def pdf(self, x):
        '''Probability density function, assuming independent variables.

        Parameters
        ----------
        x : float or obj:`torch.Tensor`
            One-dimensional tensor of inputs.

        Returns
        -------
        obj:`torch.Tensor`
            One-dimensional tensor with probabilities.
        '''
        mean, var = self.diag_stats
        # clip the variance from below to avoid division by zero
        var = var.clip(min=1e-30)
        deviation = x - mean
        exponent = -(deviation ** 2.) / var / 2.
        normalizer = (var * 2 * math.pi) ** .5
        return exponent.exp() / normalizer

    def cdf(self, x):
        '''Cumulative distribution function, assuming independent variables.

        Parameters
        ----------
        x : float or obj:`torch.Tensor`
            One-dimensional tensor of inputs.

        Returns
        -------
        obj:`torch.Tensor`
            One-dimensional tensor with cumulative probabilities.
        '''
        mean, var = self.diag_stats
        # clip the variance from below to avoid division by zero
        var = var.clip(min=1e-30)
        standardized = (x - mean) / (var * 2) ** .5
        return (1. + torch.erf(standardized)) / 2.


def flatargs(func):
    '''Wrap a two-argument method such that both arguments are passed through `self.flat` first.

    Parameters
    ----------
    func : callable
        Method with signature ``func(self, x1, x2)``.

    Returns
    -------
    callable
        Method with the same signature, which calls `func` with ``self.flat(x1)`` and ``self.flat(x2)``.
    '''
    @wraps(func)
    def flattened_call(self, x1, x2):
        flat_x1 = self.flat(x1)
        flat_x2 = self.flat(x2)
        return func(self, flat_x1, flat_x2)

    return flattened_call


class Kernel:
    '''Base class for kernels.

    Sub-classes list the names of the attributes which describe the kernel in `__kernel_params__`. These are used for
    the representation, for serialization and to collect tensor parameters.

    Parameters
    ----------
    n_feature_dim : int
        Number of trailing axes of the inputs which are flattened into a single feature axis by `flat`.
        Default is ``2``.
    '''
    __kernel_params__ = tuple()

    def __init__(self, n_feature_dim=2):
        # store the supplied value; it was previously ignored in favor of a hard-coded 2
        self.n_feature_dim = n_feature_dim

    def flat(self, x):
        '''Flatten the last `n_feature_dim` axes of `x` into a single axis.'''
        return x.flatten(start_dim=-self.n_feature_dim)

    def __call__(self, x1, x2):
        '''Compute kernel values.

        The base implementation does nothing and returns None.

        Parameters
        ----------
        x1 : :py:obj:`torch.Tensor`
            First set of points with shape (number of samples x dimensions).
        x2 : :py:obj:`torch.Tensor`
            Second set of points with shape (number of samples x dimensions).

        Returns
        -------
        obj:`torch.tensor`
            Gram matrix of kernel with shape (number of samples x number of samples).

        '''

    def __repr__(self):
        params = ', '.join(f'{key}={atomize(getattr(self, key))}' for key in self.__kernel_params__)
        return f'{self.__class__.__name__}({params})'

    def diag(self, x1):
        '''Compute diagonal kernel values.

        The base implementation does nothing and returns None.

        Parameters
        ----------
        x1 : :py:obj:`torch.Tensor`
            Set of points with shape (number of samples x dimensions).

        Returns
        -------
        obj:`torch.tensor`
            Diagonal Gram matrix of kernel of (x1, x1) with shape (number of samples).

        '''

    def param_dict(self):
        '''Return a dict of all attributes listed in `__kernel_params__`, with tensors converted to scalars.'''
        return {key: atomize(getattr(self, key)) for key in self.__kernel_params__}

    def serialize(self):
        '''Return a JSON string of the registered kernel name and the kernel parameters.

        Requires `__serialized_name__`, which is set by `register_kernel`.
        '''
        return json.dumps((self.__serialized_name__, self.param_dict()))

    @staticmethod
    def deserialize(string):
        '''Create a kernel from a JSON string produced by `serialize`.

        The kernel class is looked up by name in `KERNELS` and called with the stored parameters as keyword
        arguments.
        '''
        name, kwargs = json.loads(string)
        return KERNELS[name](**kwargs)

    def parameters(self):
        '''Return a generator over the attributes listed in `__kernel_params__` which are tensors.'''
        return (
            getattr(self, key) for key in self.__kernel_params__ if isinstance(getattr(self, key), torch.Tensor)
        )


class VQEKernelGrad(Kernel):
    '''Kernel built from the per-dimension factors of `VQEKernel`, used as `VQEKernel.grad` and `VQEKernel.grad2`.

    For each feature dimension, the factor of that dimension in the product over dimensions is replaced by a sine
    (`order == 1`) or cosine (any other `order`) term.
    NOTE(review): this looks like the first derivative and the mixed second derivative of `VQEKernel` with respect
    to a single input dimension -- confirm (see the TODO in `__call__`).

    Parameters
    ----------
    sigma_0 : float
        Prior standard deviation; its square scales the kernel values.
    gamma : float
        Kernel parameter, stored as a tensor.
    n_feature_dim : int
        Number of trailing feature axes of the inputs, passed to `Kernel`.
    order : int
        ``1`` selects the sine term, any other value selects the cosine term.
    '''
    __kernel_params__ = ('gamma',)

    def __init__(self, sigma_0=1.0, gamma=2.0, n_feature_dim=2, order=1):
        super().__init__(n_feature_dim=n_feature_dim)
        self.sigma_0 = sigma_0
        self.gamma = torch.tensor(gamma)
        self.sigma_0_sq = sigma_0 ** 2
        self.order = order

    def __call__(self, x1, x2):
        '''Compute the kernel values for each feature dimension.

        Parameters
        ----------
        x1 : :py:obj:`torch.Tensor`
            First set of points, with `n_feature_dim` trailing feature axes.
        x2 : :py:obj:`torch.Tensor`
            Second set of points, with `n_feature_dim` trailing feature axes.

        Returns
        -------
        :py:obj:`torch.Tensor`
            Kernel values, where the feature axes (with the shape of the feature axes of `x1`) come before the two
            sample axes of `x1` and `x2`, which are the last two axes.
        '''
        # remember the unflattened feature shape to restore it at the end
        shape = x1.shape[-self.n_feature_dim:]
        x1, x2 = self.flat(x1), self.flat(x2)
        diff = cdiff(x1, x2)

        trig = torch.sin if self.order == 1 else torch.cos

        # TODO: check this
        # per-dimension factors, the same expression as used in VQEKernel
        gram_matrix = (self.gamma ** 2 + 2 * torch.cos(diff)) / (2 + self.gamma ** 2)
        prefix = 2 * trig(diff) / (2 + self.gamma ** 2)
        # product over all dimensions divided by the factor of the current dimension, i.e. the product of all other
        # factors; NOTE(review): this divides by zero wherever a factor is exactly zero -- confirm this cannot occur
        kern = self.sigma_0_sq * (prefix * (gram_matrix.prod(dim=-1, keepdim=True)/gram_matrix))

        # gram_matrix2 = (self.gamma ** 2 + 2 * torch.cos(diff)) / (2 + self.gamma ** 2)
        # prefix2 = (2 * torch.sin(diff)) / (2 + self.gamma ** 2)
        # kern2 = self.sigma_0_sq * prefix2 * (
        #     gram_matrix2[..., :, None] ** (1 - torch.eye(gram_matrix2.shape[-1]))
        # ).prod(axis=-1)

        # restore the feature axes, then move the two sample axes behind them
        kern = kern.reshape(kern.shape[:-1] + shape)
        return torch.movedim(kern, (-self.n_feature_dim - 1, -self.n_feature_dim - 2), (-1, -2))

    def diag(self, x1):
        '''Compute constant diagonal values, with the shape of `x1` without its feature axes.

        Returns zeros for `order == 1`, and ``2 * sigma_0 ** 2 / (gamma ** 2 + 2)`` otherwise.
        '''
        if self.order == 1:
            return torch.zeros(x1.shape[:-self.n_feature_dim], dtype=x1.dtype)
        return torch.full(x1.shape[:-self.n_feature_dim], 2 * self.sigma_0_sq / (self.gamma ** 2 + 2), dtype=x1.dtype)


@register_kernel('vqe')
class VQEKernel(Kernel):
    '''Kernel given by the product over all feature dimensions of ``(gamma ** 2 + 2 * cos(x1 - x2)) / (2 + gamma ** 2)``,
    scaled by ``sigma_0 ** 2``.

    Parameters
    ----------
    sigma_0 : float
        Prior standard deviation; its square scales the kernel values.
    gamma : float
        Kernel parameter, stored as a tensor.
    n_feature_dim : int
        Number of trailing feature axes of the inputs, passed to `Kernel`.
    '''
    __kernel_params__ = ('sigma_0', 'gamma', 'n_feature_dim')
    _grad_kernel = VQEKernelGrad

    def __init__(self, sigma_0=1.0, gamma=2.0, n_feature_dim=2):
        super().__init__(n_feature_dim=n_feature_dim)
        self.sigma_0 = sigma_0
        self.gamma = torch.tensor(gamma)
        self.sigma_0_sq = sigma_0 ** 2
        # gradient kernels are created lazily by the `grad` and `grad2` properties
        self._grad = None
        self._grad2 = None

    def _make_grad_kernel(self, order):
        '''Create a gradient kernel of the given order from the current parameter values.'''
        return self._grad_kernel(self.sigma_0, self.gamma.detach().item(), self.n_feature_dim, order=order)

    @flatargs
    def __call__(self, x1, x2):
        '''Compute the Gram matrix between the (flattened) points `x1` and `x2`.'''
        factors = (self.gamma ** 2 + 2 * torch.cos(cdiff(x1, x2))) / (2 + self.gamma ** 2)
        return self.sigma_0_sq * factors.prod(dim=-1)

    def diag(self, x1):
        '''Compute the diagonal of the Gram matrix of (x1, x1) by evaluating the kernel at zero difference.'''
        no_diff = self.flat(torch.zeros(x1.shape, dtype=x1.dtype))
        factors = (self.gamma ** 2 + 2 * torch.cos(no_diff)) / (2 + self.gamma ** 2)
        return self.sigma_0_sq * factors.prod(dim=-1)

    @property
    def grad(self):
        '''Return the gradient kernel of order 1, created on first access and cached.'''
        if self._grad is None:
            self._grad = self._make_grad_kernel(1)
        return self._grad

    @property
    def grad2(self):
        '''Return the gradient kernel of order 2, created on first access and cached.'''
        if self._grad2 is None:
            self._grad2 = self._make_grad_kernel(2)
        return self._grad2


@register_kernel('rbf')
class RBFKernel(Kernel):
    '''Radial basis function kernel ``sigma_0 ** 2 * exp(-0.5 * sum(((x1 - x2) / gamma) ** 2))``.

    Parameters
    ----------
    sigma_0 : float
        Prior standard deviation; its square scales the kernel values.
    gamma : float
        Length scale, stored as a tensor.
    n_feature_dim : int
        Number of trailing feature axes of the inputs, passed to `Kernel`.
    '''
    # list all constructor arguments (as done in VQEKernel), such that serialize/deserialize restores sigma_0 and
    # n_feature_dim instead of silently falling back to their defaults; strings serialized with only `gamma` can
    # still be deserialized
    __kernel_params__ = ('sigma_0', 'gamma', 'n_feature_dim')

    def __init__(self, sigma_0=1.0, gamma=1.0, n_feature_dim=2):
        super().__init__(n_feature_dim=n_feature_dim)
        self.sigma_0 = sigma_0
        self.gamma = torch.tensor(gamma)
        self.sigma_0_sq = sigma_0 ** 2

    @flatargs
    def __call__(self, x1, x2):
        '''Compute the Gram matrix between the (flattened) points `x1` and `x2`.'''
        exp_term = (cdiff(x1, x2) / self.gamma) ** 2
        kern = self.sigma_0_sq * torch.exp(-0.5 * exp_term.sum(-1))
        return kern

    def diag(self, x1):
        '''Compute the diagonal of the Gram matrix of (x1, x1), which is constant ``sigma_0 ** 2``.'''
        return torch.full(x1.shape[:-self.n_feature_dim], self.sigma_0_sq, dtype=x1.dtype)


class MeasureGoal(str, Enum):
    '''Enum to indicate the goal of measurements added to the GP. This is used to implement more advanced inducing
    point algorithms.

    Members are also strings (the class derives from `str`), and can be created from their string value, e.g.
    ``MeasureGoal('pivot')``.
    '''
    # default goal of `MeasureMeta`
    INIT = 'init'
    PLAIN = 'plain'
    # LEFT and RIGHT are used by the pivot inducers to reconstruct a pivot point as their midpoint
    LEFT = 'left'
    RIGHT = 'right'
    # used directly as the pivot point by the pivot inducers
    PIVOT = 'pivot'
    # assigned by the pivot inducers to the sample which summarizes discarded samples
    DISTILL = 'distill'

    def __repr__(self):
        # represent members only by their name, e.g. `PIVOT`
        return self.name


class MeasureMeta(NamedTuple):
    '''Meta data of Measurements for GPs used for advanced inducing point algorithms.'''
    # optimization step at which the sample was measured; the pivot inducers skip or reject entries with step None
    step: int = 0
    # goal for which the sample was measured
    goal: MeasureGoal = MeasureGoal.INIT
    # readout count associated with the sample, or None; the pivot inducers sum this over discarded samples
    readout: int = None


@register_inducer('last_slack')
class LastSlackInducer:
    '''Inducing point algorithm which keeps only the `retain` most recent measurements as soon as the number of
    measurements exceeds `retain` + `slack`.

    Parameters
    ----------
    retain : int
        Number of most recent measurements to keep when discarding.
    slack : int
        Number of measurements beyond `retain` which are tolerated before discarding.
    '''
    def __init__(self, retain: int, slack: int):
        self.retain = retain
        self.slack = slack

    def __call__(self, model):
        '''Discard old measurements of `model` in-place if it holds more than `retain` + `slack` measurements.'''
        capacity = self.retain + self.slack
        if len(model) <= capacity:
            return
        keep_recent = slice(-self.retain, None)
        model.slice(keep_recent)


@register_inducer('last_slack_pivot')
class LastSlackPivotInducer:
    '''Estimate the pivot point at the BO-step at the state when the GP had `retain` less training samples. Create a
    temporary gaussian process from samples measured only at previous bayesian optimization steps. Compute the
    predictive posterior mean and variance at the pivot point, and discard all training samples used for the temporary
    GP from the original GP. Finally, add the predictive posterior mean and variance at the pivot point as a new sample
    to the original GP.

    Parameters
    ----------
    retain : int
        Offset from the end of the training data which identifies the first sample to be kept.
    slack : int
        Number of training samples beyond `retain` which are tolerated before inducing.
    '''

    def __init__(self, retain: int, slack: int):
        self.retain = retain
        self.slack = slack

    def __call__(self, model):
        '''Condense the oldest samples of `model` into a single sample at the pivot point, in-place.

        Does nothing unless `model` holds more than `retain` + `slack` training samples.

        Raises
        ------
        RuntimeError
            If the first retained sample has no step, or if no pivot point can be found or reconstructed.
        '''
        if len(model.x_train) > self.retain + self.slack:
            if model.meta[-1].goal == MeasureGoal.DISTILL:
                # the most recent sample is a distilled one: rotate it to the front of the meta data and of the
                # training data, then recompute the covariance and its cholesky decomposition for the new order
                model.meta = [model.meta[-1]] + model.meta[:-1]
                for attr in ["x_train", "y_train", "cov_inv_y", "y_var"]:
                    setattr(model, attr, torch.cat((getattr(model, attr)[-1:], getattr(model, attr)[:-1])))
                model.covar = model.kernel(model.x_train, model.x_train) + model.y_var.diag()
                model.cholesky = torch.linalg.cholesky(model.covar, upper=False)
            # meta data of the oldest sample among the `retain` most recent ones
            new_first = model.meta[-self.retain]

            if new_first.step is None:
                raise RuntimeError('No valid metadata to identify pivot point!')

            # find the measurements at the previous step
            # shift_indices: first index per goal among the samples at the step of new_first
            # last_index: index of the last sample with a step smaller than the step of new_first
            shift_indices = {}
            last_index = 0
            for index, meta in enumerate(model.meta):
                # the following assumes the training data is sorted in ascending order
                if meta.step is None:
                    continue
                if meta.step > new_first.step:
                    break
                if meta.step < new_first.step:
                    last_index = index
                if meta.step == new_first.step and meta.goal not in shift_indices:
                    shift_indices[meta.goal] = index

            if new_first.step == model.meta[last_index].step:
                # cannot induce when we have only a single group
                return

            # get or reconstruct the pivot point at the previous step
            if MeasureGoal.PIVOT in shift_indices:
                x_pivot = model.x_train[[shift_indices[MeasureGoal.PIVOT]]]
            elif all(key in shift_indices for key in (MeasureGoal.LEFT, MeasureGoal.RIGHT)):
                x_left = model.x_train[shift_indices[MeasureGoal.LEFT]]
                x_right = model.x_train[shift_indices[MeasureGoal.RIGHT]]
                # mask of the coordinates in which left and right differ; `.item()` below requires that this selects
                # exactly one coordinate
                k_best = ((x_right - x_left).abs() > 1e-9)
                if x_right[k_best].item() < x_left[k_best].item():
                    # right is smaller than left in this coordinate: shift it by one period before averaging
                    x_right = x_right + k_best * math.tau
                # midpoint of left and right, mapped back to [0, tau), with a leading sample axis
                x_pivot = ((x_left + x_right) / 2.)[None] % math.tau
            else:
                raise RuntimeError('Could not find or reconstruct pivot point! Try increasing slack.')

            # compute posterior of gp using all points that will be discarded
            d_pivot = model[:last_index + 1].posterior(x_pivot, diag=True)

            # discard samples
            # sum of all readouts of the discarded samples, ignoring samples without readout
            discarded_readout = sum(elem.readout for elem in model.meta[:last_index + 1] if elem.readout is not None)
            model.slice(slice(last_index + 1, None))
            # add predicted values for pivot point to GP
            model.update(
                x_pivot,
                d_pivot.mean,
                d_pivot.var,
                meta=[MeasureMeta(step=model.meta[0].step, goal=MeasureGoal.DISTILL, readout=discarded_readout)]
            )


@register_inducer('last_slack_pivot_grad')
class LastSlackPivotInducerGrad:
    '''Estimate the pivot point at the BO-step at the state when the GP had `retain` less training samples. Create a
    temporary gaussian process from samples measured only at previous bayesian optimization steps. Compute the
    predictive posterior mean and variance of the GP posterior and of the GP gradient-posterior at the pivot point, and
    discard all training samples used for the temporary GP from the original GP. Finally, add the predictive posterior
    mean and variance at the pivot point as a new sample to the original GP.

    Parameters
    ----------
    retain : int
        Offset from the end of the training data which identifies the first sample to be kept.
    slack : int
        Number of training samples beyond `retain` which are tolerated before inducing.
    '''

    def __init__(self, retain: int, slack: int):
        self.retain = retain
        self.slack = slack

    def __call__(self, model):
        '''Condense the oldest samples of `model` into a single gradient sample at the pivot point, in-place.

        Does nothing unless `model` holds more than `retain` + `slack` training samples.

        Raises
        ------
        RuntimeError
            If the first retained sample has no step, or if no pivot point can be found or reconstructed.
        '''
        if len(model.x_train) > self.retain + self.slack:
            if model.meta[-1].goal == MeasureGoal.DISTILL:
                # the most recent sample is a distilled one: rotate it to the front of the meta data and of the
                # training data, then recompute the covariance and its cholesky decomposition for the new order
                model.meta = [model.meta[-1]] + model.meta[:-1]
                for attr in ["x_train", "y_train", "cov_inv_y", "y_var"]:
                    setattr(model, attr, torch.cat((getattr(model, attr)[-1:], getattr(model, attr)[:-1])))
                model.covar = model.kernel(model.x_train, model.x_train) + model.y_var.diag()
                model.cholesky = torch.linalg.cholesky(model.covar, upper=False)
            # meta data of the oldest sample among the `retain` most recent ones
            new_first = model.meta[-self.retain]

            if new_first.step is None:
                raise RuntimeError('No valid metadata to identify pivot point!')

            # find the measurements at the previous step
            # shift_indices: first index per goal among the samples at the step of new_first
            # last_index: index of the last sample with a step smaller than the step of new_first
            shift_indices = {}
            last_index = 0
            for index, meta in enumerate(model.meta):
                # the following assumes the training data is sorted in ascending order
                if meta.step is None:
                    continue
                if meta.step > new_first.step:
                    break
                if meta.step < new_first.step:
                    last_index = index
                if meta.step == new_first.step and meta.goal not in shift_indices:
                    shift_indices[meta.goal] = index
                    # These should be the indices that follow the point at which the iteration touches the new_first
                    # point. But because of the way GradCoRe works, meta.step == new_first.step happens before we
                    # meet the new_first point, so the shift indices are wrong: they become the indices of the first
                    # two points measured in the same step as the new_first point. The points we need to use to
                    # reconstruct the pivot are: model.x_train[shift_indices[MeasureGoal.LEFT]:-self.retain]

            if new_first.step == model.meta[last_index].step:
                # cannot induce when we have only a single group
                return

            # get or reconstruct the pivot point at the previous step
            if MeasureGoal.PIVOT in shift_indices:
                x_pivot = model.x_train[[shift_indices[MeasureGoal.PIVOT]]]
            elif all(key in shift_indices for key in (MeasureGoal.LEFT, MeasureGoal.RIGHT)):
                # samples from the first LEFT sample of the step up to (excluding) the first retained sample
                # NOTE(review): this looks like a view of model.x_train, in which case the assignment to
                # condensed[i+1] below writes into model.x_train in-place -- confirm this is intended
                condensed = model.x_train[shift_indices[MeasureGoal.LEFT]:-self.retain]
                # the samples are processed in consecutive pairs; each iteration overwrites x_pivot, such that only
                # the midpoint of the last pair is used
                # NOTE(review): presumably the pairs are (LEFT, RIGHT) -- confirm; an odd number of samples makes
                # condensed[i+1] fail, and no samples at all leaves x_pivot undefined
                for i in range(0, len(condensed), 2):
                    # mask of the coordinates in which the pair differs; `.item()` below requires that this selects
                    # exactly one coordinate
                    k_best = ((condensed[i+1] - condensed[i]).abs() > 1e-9)
                    if condensed[i+1][k_best].item() < condensed[i][k_best].item():
                        # second is smaller than first in this coordinate: shift it by one period before averaging
                        condensed[i+1] = condensed[i+1] + k_best * math.tau
                    # midpoint of the pair, mapped back to [0, tau), with a leading sample axis
                    x_pivot = ((condensed[i] + condensed[i+1]) / 2.)[None] % math.tau
            else:
                raise RuntimeError('Could not find or reconstruct pivot point! Try increasing slack.')

            # compute posterior of gp using all points that will be discarded
            # NOTE(review): shift_indices[MeasureGoal.LEFT] is read here unconditionally, which fails with a KeyError
            # if the PIVOT branch above was taken without a LEFT sample at this step -- confirm this cannot occur
            last_index += len(model[shift_indices[MeasureGoal.LEFT]:-self.retain])
            # This has to be updated with the number of points belonging to the first step that are condensed to form
            # the pivot point. In SMO this is not the case, since the number of condensed points is always 2 and they
            # correspond to the training points indicated by shift_indices. This has the effect of making
            # new_first.step = model.meta[last_index] + 1, which in turn falsifies
            # new_first.step == model.meta[last_index]. In GradCoRe this is not the case: the new_first can belong to
            # the same group as the last_index.
            d_grad_pivot = model[:last_index + 1].posterior_grad(x_pivot, diag=True)

            # discard samples
            # sum of all readouts of the discarded samples, ignoring samples without readout
            discarded_readout = sum(elem.readout for elem in model.meta[:last_index + 1] if elem.readout is not None)
            model.slice(slice(last_index + 1, None))
            # add predicted values for pivot point to GP
            model.update_grad(
                x_pivot,
                d_grad_pivot.mean,
                d_grad_pivot.var,
                meta=[MeasureMeta(step=model.meta[0].step, goal=MeasureGoal.DISTILL, readout=discarded_readout)]
            )


class GaussianProcess:
    '''A gaussian process with noise using cholesky decomposition.
    The mean function is constant zero.

    Parameters
    ----------
    x_train : :py:obj:`torch.Tensor`
        Initial training sample inputs.
    y_train : :py:obj:`torch.Tensor`
        Initial training sample function values.
    kernel : :py:obj:`torch.Tensor`
        The covariance function.
    y_var : :py:obj:`torch.Tensor`
        Observation noise (variance) parameter. Default is `y_var_default`.
    y_var_default : float
        Default observation noise (variance) parameter. Default is ``0.1``.

    Attributes
    ----------
    x_train : :py:obj:`torch.Tensor`
        Current training sample inputs.
    y_train : :py:obj:`torch.Tensor`
        Current training sample function values.
    kernel : :py:obj:`torch.Tensor`
        The covariance function.
    covar : :py:obj:`torch.Tensor`
        Current prior covariance matrix.
    cholesky : :py:obj:`torch.Tensor`
        Current lower triangular cholesky decomposition of `covar`.
    cov_inv_y : :py:obj:`torch.Tensor`
        Current solution x for Kx = y, where K is the covariance matrix and y are the training sample function values.
    y_var : :py:obj:`torch.Tensor`
        Observation noise (variance) parameter. Default is `y_var_default`.
    y_var_default : float
        Default observation noise (variance) parameter. Default is ``0.1``.
    inducer : callable, optional
        Function to select inducing points. Default is None.

    '''
    _state_attributes = (
        'x_train',
        'y_train',
        'kernel',
        'covar',
        'cholesky',
        'cov_inv_y',
        'y_var',
        'y_var_default',
        'meta',
    )

    def __init__(self, x_train, y_train, kernel, y_var=None, y_var_default=0.1, inducer=None, meta=None):
        '''Store the inducer and set up the gaussian process from the training data using `initialize`.'''
        self.inducer = inducer
        self.initialize(x_train, y_train, kernel, y_var=y_var, y_var_default=y_var_default, meta=meta)

    def __repr__(self):
        '''Return a short description with the number of training samples, the kernel and the default noise.'''
        cls_name = self.__class__.__name__
        n_train = len(self.x_train)
        return f'{cls_name}(size={n_train:d}, kernel={self.kernel}, y_var_default={self.y_var_default:.2e})'

    def __len__(self):
        return len(self.x_train)

    def __getitem__(self, key):
        '''Return a new gaussian process which holds only the training samples selected by the slice `key`.

        The result is a shallow copy of this instance to which `slice` is applied; this instance is used as the
        starting point only.

        Parameters
        ----------
        key : slice
            Selection of training samples.

        Returns
        -------
        obj:`GaussianProcess`
            Sliced shallow copy.

        Raises
        ------
        IndexError
            If `key` is not a slice.
        '''
        if not isinstance(key, slice):
            raise IndexError('GaussianProcess only supports slicing!')

        # the slice indices are resolved inside `slice`; they are not needed here
        instance = copy(self)
        instance.slice(key)

        return instance

    def slice(self, key):
        if not isinstance(key, slice):
            raise IndexError('Not a valid slice!')

        start, _, step = key.indices(len(self))

        for name in ('x_train', 'y_train', 'y_var', 'meta'):
            setattr(self, name, getattr(self, name)[key])

        if start != 0 or step != 1:
            self.reinit()
        else:
            self.shrink()

    def state_dict(self):
        def detach(obj):
            if isinstance(obj, torch.Tensor):
                return obj.detach().numpy()
            return obj
        state_dict = {key: detach(getattr(self, key)) for key in self._state_attributes}
        state_dict['kernel'] = state_dict['kernel'].serialize()
        state_dict['meta'] = json.dumps(state_dict['meta'])
        return state_dict

    def load_state_dict(self, state_dict):
        '''Set all attributes listed in `_state_attributes` from a dict as produced by `state_dict`.

        Numpy arrays are converted to (cloned) tensors, the kernel is deserialized, and the meta data is restored
        from its JSON string. The supplied dict is not modified.
        '''
        restored = state_dict.copy()
        restored['kernel'] = Kernel.deserialize(restored['kernel'])
        restored['meta'] = [
            MeasureMeta(step, MeasureGoal(goal), readout) for step, goal, readout in json.loads(restored['meta'])
        ]
        for name in self._state_attributes:
            value = restored[name]
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value).clone()
            setattr(self, name, value)

    @classmethod
    def from_state_dict(cls, state_dict):
        '''Create an instance from a dict as produced by `state_dict`, without calling `__init__`.'''
        restored = object.__new__(cls)
        restored.load_state_dict(state_dict)
        return restored

    @classmethod
    def from_h5(cls, fname, key):
        '''Create an instance from the state stored in an HDF5 group.

        Parameters
        ----------
        fname : str or obj:`h5py.Group`
            Either an open HDF5 group, or a file name which is opened for reading.
        key : str
            Key of the group which holds the state, relative to `fname`.

        Returns
        -------
        obj:`GaussianProcess`
            Instance restored using `from_state_dict`.
        '''
        def restore(group):
            # read the full content of each entry of the group
            return cls.from_state_dict({name: entry[()] for name, entry in group.items()})

        if not isinstance(fname, h5py.Group):
            with h5py.File(fname, 'r') as fd:
                return restore(fd[key])
        return restore(fname[key])

    def initialize(self, x_train, y_train, kernel, y_var=None, y_var_default=0.1, meta=None):
        '''Set the training data and kernel, and compute the covariance, its cholesky decomposition and `cov_inv_y`.

        Parameters
        ----------
        x_train : :py:obj:`torch.Tensor`
            Training sample inputs.
        y_train : :py:obj:`torch.Tensor`
            Training sample function values.
        kernel : callable
            The covariance function.
        y_var : optional
            Observation noise (variance), passed to `expand_y_var`.
        y_var_default : float
            Default observation noise (variance). Default is ``0.1``.
        meta : list of obj:`MeasureMeta`, optional
            Meta data of the training samples. If None, a default `MeasureMeta` is used for each sample.
        '''
        n_train = len(x_train)

        self.x_train = x_train
        self.y_train = y_train
        self.kernel = kernel
        # the default is set before the noise is expanded
        self.y_var_default = y_var_default
        self.y_var = self.expand_y_var(n_train, y_var, diag_embed=False)

        # kernel matrix of the training data with the observation noise added to its diagonal
        noise = torch.diag_embed(self.y_var)
        self.covar = kernel(x_train, x_train) + noise
        self.cholesky = torch.linalg.cholesky(self.covar, upper=False)
        self.cov_inv_y = torch.cholesky_solve(self.y_train[:, None], self.cholesky, upper=False)

        self.meta = [MeasureMeta()] * n_train if meta is None else meta

    def posterior(self, x_test, noise_level=None, diag=False, y_train=None):
        '''Compute the predictive posterior distribution.

        Parameters
        ----------
        x_test : :py:obj:`torch.Tensor`
            Sample values for which to compute the predictive posterior.
        noise_level : float, optional
            Independent noise as gaussian variance to be added to the distribution.
        diag : bool, optional
            Whether only the diagonal of the covariance should be computed. Default is False.
        y_train : :py:obj:`torch.Tensor`, optional
            If supplied, compute the predictive posterior distribution under resampled values for y_train.

        Returns
        -------
        distribution : obj:`MultivariateNormal`
            The predictive posterior as a multivariate normal distribution.

        '''
        # kernel values between the training samples and the test samples
        kernel_test = self.kernel(self.x_train, x_test)

        cov_inv_y = self.cov_inv_y
        if y_train is not None:
            # compute posterior for resampled y-values, needed for noisy expected improvement
            cov_inv_y = torch.cholesky_solve(y_train[..., None], self.cholesky, upper=False)

        post_mean = (kernel_test.transpose(-1, -2) @ cov_inv_y)[..., 0]
        # solve against the training covariance using its cholesky decomposition
        cov_inv_test = torch.cholesky_solve(kernel_test, self.cholesky, upper=False)

        if diag:
            prior_covar = self.kernel.diag(x_test)
            # only the variances are computed; a singleton axis marks the result as diagonal for MultivariateNormal
            post_covar = (prior_covar - (kernel_test * cov_inv_test).sum(-2))[..., None, :]
            if noise_level is not None:
                post_covar = post_covar + noise_level
        else:
            prior_covar = self.kernel(x_test, x_test)
            post_covar = prior_covar - kernel_test.transpose(-1, -2) @ cov_inv_test
            if noise_level is not None:
                # post_covar.diagonal()[()] += noise_level
                # add the noise only to the diagonal of the full covariance
                post_covar = post_covar + torch.diag_embed(torch.tensor((noise_level,) * len(x_test)))
        return MultivariateNormal(post_mean, post_covar)

    def posterior_grad(self, x_test, noise_level=None, diag=False, y_train=None):
        '''Compute the predictive posterior distribution gradient.

        This follows the same computation as `posterior`, but uses `self.kernel.grad` between training and test
        samples, and `self.kernel.grad2` for the prior covariance of the test samples.

        Parameters
        ----------
        x_test : :py:obj:`torch.Tensor`
            Sample values for which to compute the predictive posterior.
        noise_level : float, optional
            Independent noise as gaussian variance to be added to the distribution.
        diag : bool, optional
            Whether only the diagonal of the covariance should be computed. Default is False.
        y_train : :py:obj:`torch.Tensor`, optional
            If supplied, compute the predictive posterior distribution under resampled values for y_train.

        Returns
        -------
        distribution : obj:`MultivariateNormal`
            The predictive posterior as a multivariate normal distribution.

        '''
        # TODO: this is a temporary method!!
        # requires a kernel which provides `grad` and `grad2`
        kernel_test = self.kernel.grad(self.x_train, x_test)

        cov_inv_y = self.cov_inv_y
        if y_train is not None:
            # compute posterior for resampled y-values, needed for noisy expected improvement
            cov_inv_y = torch.cholesky_solve(y_train[..., None], self.cholesky, upper=False)

        post_mean = (kernel_test.transpose(-1, -2) @ cov_inv_y)[..., 0]
        # solve against the training covariance using its cholesky decomposition
        cov_inv_test = torch.cholesky_solve(kernel_test, self.cholesky, upper=False)

        if diag:
            prior_covar = self.kernel.grad2.diag(x_test)
            # only the variances are computed; a singleton axis marks the result as diagonal for MultivariateNormal
            post_covar = (prior_covar - (kernel_test * cov_inv_test).sum(-2))[..., None, :]
            if noise_level is not None:
                post_covar = post_covar + noise_level
        else:
            prior_covar = self.kernel.grad2(x_test, x_test)
            post_covar = prior_covar - kernel_test.transpose(-1, -2) @ cov_inv_test
            if noise_level is not None:
                # post_covar.diagonal()[()] += noise_level
                # add the noise only to the diagonal of the full covariance
                post_covar = post_covar + torch.diag_embed(torch.tensor((noise_level,) * len(x_test)))

        return MultivariateNormal(post_mean, post_covar)

    def peek_posterior(self, x_peek, x_test, y_peek=None, y_var=None, noise_level=None, diag=False):
        '''Compute the predictive posterior as if `x_peek` had been observed in addition to the training data.

        The stored Cholesky factor of the training Gram matrix is extended by the peeked block via
        :py:meth:`cholesky_append`; the GP itself is not modified.

        Parameters
        ----------
        x_peek : :py:obj:`torch.Tensor`
            Sample values assumed to be known.
        x_test : :py:obj:`torch.Tensor`
            Sample values for which to compute the predictive posterior.
        y_peek : :py:obj:`torch.Tensor`, optional
            Target values assumed at `x_peek`. If :py:obj:`None`, the returned mean is all zeros and only the
            covariance is meaningful.
        y_var : float or :py:obj:`torch.Tensor`, optional
            Observation noise (variance) parameter for the peeked points. If :py:obj:`None`, use default GP
            observation noise.
        noise_level : float, optional
            Independent noise as gaussian variance to be added to the distribution.
        diag : bool, optional
            Whether only the diagonal of the covariance should be computed. Default is False.

        Returns
        -------
        distribution : obj:`MultivariateNormal`
            The predictive posterior as a multivariate normal distribution.

        '''
        solve_tri = torch.linalg.solve_triangular
        chol_train = self.cholesky
        chol_train_t = chol_train.transpose(-1, -2)

        # these quantities depend on the peeked points only
        gram_train_peek = self.kernel(self.x_train, x_peek)
        gram_peek = self.kernel(x_peek, x_peek)
        gram_peek = gram_peek + self.expand_y_var(gram_peek.shape[-1], y_var, diag_embed=True)
        chol_cross, chol_peek = self.cholesky_append(gram_train_peek, gram_peek)
        chol_cross_t = chol_cross.transpose(-1, -2)
        chol_peek_t = chol_peek.transpose(-1, -2)

        def block_solve(rhs_train, rhs_peek):
            # solve with the joint (train + peek) Gram matrix using its block Cholesky factor
            # [[chol_train, 0], [chol_cross, chol_peek]]: forward substitution, then backward substitution
            fwd_train = solve_tri(chol_train, rhs_train, upper=False)
            fwd_peek = solve_tri(chol_peek, rhs_peek - chol_cross @ fwd_train, upper=False)
            sol_peek = solve_tri(chol_peek_t, fwd_peek, upper=True)
            sol_train = solve_tri(chol_train_t, fwd_train - chol_cross_t @ sol_peek, upper=True)
            return sol_train, sol_peek

        # these quantities depend on the test points
        gram_train_test = self.kernel(self.x_train, x_test)
        gram_peek_test = self.kernel(x_peek, x_test)
        sol_train, sol_peek = block_solve(gram_train_test, gram_peek_test)

        if not diag:
            post_covar = (
                self.kernel(x_test, x_test)
                - gram_train_test.transpose(-1, -2) @ sol_train
                - gram_peek_test.transpose(-1, -2) @ sol_peek
            )
            if noise_level is not None:
                post_covar = post_covar + torch.diag_embed(torch.tensor((noise_level,) * len(x_test)))
        else:
            # only the diagonal: column-wise inner products instead of full matrix products
            post_covar = (
                self.kernel.diag(x_test)
                - (gram_train_test * sol_train).sum(-2)
                - (gram_peek_test * sol_peek).sum(-2)
            )[..., None, :]
            if noise_level is not None:
                post_covar = post_covar + noise_level

        if y_peek is None:
            # without peeked targets there is no mean to compute; return zeros of matching shape
            mean_shape = tuple(post_covar.shape[:-2]) + (post_covar.shape[-1],)
            post_mean = torch.zeros(mean_shape, dtype=post_covar.dtype, device=post_covar.device)
        else:
            wts_train, wts_peek = block_solve(self.y_train[:, None], y_peek[:, None])
            post_mean = (
                gram_train_test.transpose(-1, -2) @ wts_train
                + gram_peek_test.transpose(-1, -2) @ wts_peek
            ).squeeze(-1)

        return MultivariateNormal(post_mean, post_covar)

    def peek_posterior_grad(self, x_peek, x_test, y_peek=None, y_var=None, noise_level=None, diag=False):
        '''Compute a predictive posterior from the kernel's gradient terms, conditioned on the peeked points.

        The cross-covariance is taken from ``self.kernel.grad`` and the prior covariance from
        ``self.kernel.grad2``. Only `x_peek` (and `y_peek`) enter the conditioning here: neither
        ``self.x_train`` nor the stored Cholesky factor is read by this method.

        Parameters
        ----------
        x_peek : :py:obj:`torch.Tensor`
            Sample values assumed to be known.
        x_test : :py:obj:`torch.Tensor`
            Sample values for which to compute the predictive posterior.
        y_peek : :py:obj:`torch.Tensor`, optional
            Target values assumed at `x_peek`. If :py:obj:`None`, the returned mean is all zeros.
        y_var : float or :py:obj:`torch.Tensor`, optional
            Observation noise (variance) parameter. If :py:obj:`None`, use default GP observation noise.
        noise_level : float, optional
            Independent noise as gaussian variance to be added to the distribution.
        diag : bool, optional
            Whether only the diagonal of the covariance should be computed. Default is False.

        Returns
        -------
        distribution : obj:`MultivariateNormal`
            The predictive posterior as a multivariate normal distribution.

        '''
        # noisy Gram matrix of the peeked points and its lower Cholesky factor
        peek_covar = self.kernel(x_peek, x_peek)
        peek_covar = peek_covar + self.expand_y_var(peek_covar.shape[-1], y_var, diag_embed=True)
        peek_cholesky = torch.linalg.cholesky(peek_covar, upper=False)
        kernel_peek_test = self.kernel.grad(x_peek, x_test)
        # NOTE(review): the shape returned by kernel.grad is not visible here; the extra trailing axis added by
        # [:, :, None] changes which axes cholesky_solve treats as the matrix -- confirm this is intended
        cov_inv_test = torch.cholesky_solve(kernel_peek_test[:, :, None], peek_cholesky, upper=False)

        if diag:
            prior_covar = self.kernel.grad2.diag(x_test)
            post_covar = (prior_covar - (kernel_peek_test[:, :, None] * cov_inv_test).sum(-2))[..., None, :]
            if noise_level is not None:
                post_covar = post_covar + noise_level
        else:
            prior_covar = self.kernel.grad2(x_test, x_test)
            # NOTE(review): cov_inv_test carries the extra axis from above while kernel_peek_test does not;
            # verify that this matmul broadcasts as expected
            post_covar = prior_covar - kernel_peek_test.transpose(-1, -2) @ cov_inv_test
            if noise_level is not None:
                # the noise tensor is created with torch defaults (dtype/device), not those of post_covar
                post_covar = post_covar + torch.diag_embed(torch.tensor((noise_level,) * len(x_test)))

        if y_peek is not None:
            cov_inv_y = torch.cholesky_solve(y_peek[..., None], peek_cholesky, upper=False)
            post_mean = (kernel_peek_test.transpose(-1, -2) @ cov_inv_y)[..., 0]
        else:
            # no peeked targets: zero mean with the batch shape and last dimension of the covariance
            post_mean = torch.zeros(
                (*post_covar.shape[:-2], post_covar.shape[-1]),
                dtype=post_covar.dtype,
                device=post_covar.device
            )
        return MultivariateNormal(post_mean, post_covar)

    def prior(self):
        '''Return the distribution over the training points as a multivariate normal.

        The distribution is built from the stored training targets as mean and the stored training covariance;
        the cached Cholesky factor is handed over alongside.

        Returns
        -------
        distribution : obj:`MultivariateNormal`
            The prior as a multivariate normal distribution.
        '''
        mean, covar, factor = self.y_train, self.covar, self.cholesky
        return MultivariateNormal(mean, covar, cholesky=factor)

    def expand_y_var(self, n_samples, y_var=None, diag_embed=False, shape=None):
        '''Broadcast the observation noise to `n_samples` entries.

        Parameters
        ----------
        n_samples : int
            Number of samples for which the observation noise is needed.
        y_var : float or :py:obj:`torch.Tensor`, optional
            Observation noise (variance). Scalars (Python numbers or zero-dimensional tensors) are repeated
            `n_samples` times, tensors with a trailing dimension of size 1 are expanded along that dimension, and
            anything else is passed through unchanged. If :py:obj:`None`, the GP's default observation noise is used.
        diag_embed : bool, optional
            If ``False``, return the noise as a flat vector. Otherwise, embed it as the diagonal of a square matrix.
        shape : tuple, optional
            Currently not used by the implementation.

        Returns
        -------
        y_var_expanded : obj:`torch.Tensor`
            The expanded `y_var`, either flat, or embedded into the diagonal of a square matrix.
        '''
        noise = self.y_var_default if y_var is None else y_var

        # python numbers and zero-dimensional tensors count as scalars
        is_scalar = isinstance(noise, (int, float)) or not noise.shape
        if is_scalar:
            noise = torch.full((n_samples,), noise)

        if noise.shape[-1] == 1:
            noise = noise.expand((*noise.shape[:-1], n_samples))

        return torch.diag_embed(noise) if diag_embed else noise

    def update(self, x_cand, y_cand, y_var=None, meta=None):
        '''Append new observations to the training set and extend all cached matrices accordingly.

        The Cholesky factor is grown block-wise via :py:meth:`cholesky_append` instead of being recomputed.

        Parameters
        ----------
        x_cand : :py:obj:`torch.Tensor`
            New (batch of) candidate points to be added.
        y_cand : :py:obj:`torch.Tensor`
            Function values of candidate points to be added.
        y_var : float or :py:obj:`torch.Tensor`, optional
            Observation noise (variance) parameter of the new observations. Default is the GP's observation noise.
        meta : list, optional
            Meta information, one entry per candidate. If :py:obj:`None`, default ``MeasureMeta`` entries are used.
        '''
        n_old = len(self.covar)
        n_all = n_old + len(x_cand)

        cross = self.kernel(self.x_train, x_cand)
        block_new = self.kernel(x_cand, x_cand)

        noise = self.expand_y_var(block_new.shape[-1], y_var, diag_embed=False)
        block_new = block_new + torch.diag_embed(noise)

        # joint Gram matrix [[K_old, cross], [cross^T, K_new]]
        gram = torch.empty((n_all, n_all), dtype=self.covar.dtype)
        gram[:n_old, :n_old] = self.covar
        gram[:n_old, n_old:] = cross
        gram[n_old:, :n_old] = cross.t()
        gram[n_old:, n_old:] = block_new

        lower_left, lower_right = self.cholesky_append(cross, block_new)

        # joint lower Cholesky factor [[L_old, 0], [L21, L22]]
        factor = torch.empty((n_all, n_all), dtype=self.cholesky.dtype)
        factor[:n_old, :n_old] = self.cholesky
        factor[:n_old, n_old:] = 0.
        factor[n_old:, :n_old] = lower_left
        factor[n_old:, n_old:] = lower_right

        targets = torch.cat((self.y_train, y_cand), dim=0)
        weights = torch.cholesky_solve(targets[:, None], factor, upper=False)

        # commit the new state only after all computations succeeded
        self.x_train = torch.cat((self.x_train, x_cand), dim=0)
        self.y_train = targets
        self.y_var = torch.cat((self.y_var, noise), dim=0)
        self.covar, self.cholesky, self.cov_inv_y = gram, factor, weights

        if meta is None:
            meta = [MeasureMeta()] * len(x_cand)
        self.meta += meta

        if self.inducer is not None:
            self.inducer(self)

    def update_grad(self, x_cand, y_cand, y_var=None, meta=None):
        """ Update the training points with new candidate point(s) using the kernel's gradient terms.

        Follows the same steps as :py:meth:`update`, except that the cross block is taken from
        ``self.kernel.grad`` and the candidate block from ``self.kernel.grad2.diag``.

        Parameters
        ----------
        x_cand : :py:obj:`torch.Tensor`
            New (batch of) candidate points to be added.
        y_cand : :py:obj:`torch.Tensor`
            Gradient values of candidate points to be added.
        y_var : float or :py:obj:`torch.Tensor`, optional
            Observation noise (variance) parameter of the new observations. Default is the GP's observation noise.
        meta : list, optional
            Meta information, one entry per candidate. If :py:obj:`None`, default ``MeasureMeta`` entries are used.
        """
        dim = len(self.covar)
        dnew = len(x_cand)
        # NOTE(review): argument order is (x_cand, x_train) here, whereas update() calls the kernel with
        # (x_train, x_cand); the result is written into the [:dim, dim:] block below -- confirm the orientation
        kernel_both = self.kernel.grad(x_cand, self.x_train)
        kernel_cand = self.kernel.grad2.diag(x_cand)

        y_var = self.expand_y_var(kernel_cand.shape[-1], y_var, diag_embed=False)

        # NOTE(review): if grad2.diag returns only the diagonal entries, this sum broadcasts them against the
        # embedded noise matrix instead of placing them on the diagonal -- confirm the shape of grad2.diag
        kernel_cand = kernel_cand + torch.diag_embed(y_var)

        # joint Gram matrix assembled from the old matrix, the cross block, and the candidate block
        covar = torch.empty((dim + dnew,) * 2, dtype=self.covar.dtype)
        covar[:dim, :dim] = self.covar
        covar[:dim, dim:] = kernel_both
        covar[dim:, :dim] = kernel_both.t()
        covar[dim:, dim:] = kernel_cand

        # raises SingularGramError (see cholesky_append) before any state on self is modified
        L21, L22 = self.cholesky_append(kernel_both, kernel_cand)

        # joint lower Cholesky factor; the upper-right block is explicitly zeroed since torch.empty is uninitialized
        chol = torch.empty((dim + dnew,) * 2, dtype=self.cholesky.dtype)
        chol[:dim, :dim] = self.cholesky
        chol[:dim, dim:] = 0.
        chol[dim:, :dim] = L21
        chol[dim:, dim:] = L22

        y_train = torch.cat((self.y_train, y_cand), dim=0)
        cov_inv_y = torch.cholesky_solve(y_train[:, None], chol, upper=False)

        self.x_train = torch.cat((self.x_train, x_cand), dim=0)
        self.y_train = y_train

        self.y_var = torch.cat((self.y_var, y_var), dim=0)

        self.covar = covar
        self.cholesky = chol
        self.cov_inv_y = cov_inv_y

        if meta is None:
            # the same MeasureMeta instance is repeated once per candidate
            meta = [MeasureMeta()] * len(x_cand)
        self.meta += meta

        if self.inducer is not None:
            self.inducer(self)

    def reinit(self):
        '''Re-run :py:meth:`initialize` with the currently stored training data, kernel, noise, and meta data.'''
        data = (self.x_train, self.y_train, self.kernel)
        options = {
            'y_var': self.y_var,
            'y_var_default': self.y_var_default,
            'meta': self.meta,
        }
        self.initialize(*data, **options)

    def shrink(self):
        '''Truncate the cached covariance and Cholesky factor to ``len(self)`` rows and columns.

        The solution of the linear system with the training targets is recomputed from the truncated factor.
        '''
        size = len(self)
        leading = (..., slice(None, size), slice(None, size))
        self.covar = self.covar[leading]
        self.cholesky = self.cholesky[leading]
        # note: the block-wise approach from peek_posterior could do this about twice as fast
        self.cov_inv_y = torch.cholesky_solve(self.y_train[:, None], self.cholesky, upper=False)

    def log_likelihood(self):
        '''Compute the log marginal likelihood of the training targets from the cached quantities.

        Returns
        -------
        log_likelihood : obj:`torch.Tensor`
            Negative sum of the data-fit term, the sum of the log-diagonal of the Cholesky factor (half the
            log-determinant of the covariance), and the normalization constant.
        '''
        n_obs = len(self.y_train)
        data_fit = 0.5 * self.y_train.t() @ self.cov_inv_y
        half_log_det = self.cholesky.diagonal().log().sum()
        norm_const = 0.5 * n_obs * math.log(2 * math.pi)
        return -(data_fit + half_log_det + norm_const)

    def grad_log_likelihood(self, kernel_grad):
        '''Compute the derivative of the log marginal likelihood with respect to a kernel hyperparameter.

        With ``alpha = K^-1 y`` (cached as ``self.cov_inv_y``), the derivative is
        ``0.5 * tr((alpha alpha^T - K^-1) dK)``, where ``dK`` is the derivative of the covariance matrix with
        respect to the hyperparameter.

        Parameters
        ----------
        kernel_grad : :py:obj:`torch.Tensor`
            Derivative of the training covariance matrix with respect to the hyperparameter, a square matrix of the
            same size as the training covariance.

        Returns
        -------
        grad : obj:`torch.Tensor`
            Zero-dimensional tensor with the derivative of the log marginal likelihood.
        '''
        alpha = self.cov_inv_y
        # tr(alpha alpha^T dK) equals the quadratic form alpha^T dK alpha; the inner product alpha^T alpha must not
        # be formed first, as it is 1x1 and cannot be multiplied with dK for more than one training point
        quad_form = (alpha.transpose(-1, -2) @ kernel_grad @ alpha).squeeze()
        # tr(K^-1 dK), with the solve done using the cached Cholesky factor
        trace_term = torch.cholesky_solve(kernel_grad, self.cholesky, upper=False).trace()
        return 0.5 * (quad_form - trace_term)

    def loocv_mll_closed(self):
        '''Compute the closed-form leave-one-out cross-validation objective from the cached quantities.

        The leave-one-out predictive mean and variance of each training point are obtained from the diagonal of the
        inverse covariance. The constant ``log(2 * pi)`` term is not included in the returned sum.

        Returns
        -------
        loocv : obj:`torch.Tensor`
            Sum over all training points of ``-0.5 * (log(var) + (y - mu) ** 2 / var)``.
        '''
        precision_diag = torch.diag(torch.cholesky_inverse(self.cholesky, upper=False))
        targets = self.y_train
        loo_mean = targets - self.cov_inv_y.squeeze(1) / precision_diag
        loo_variance = 1 / precision_diag
        per_point = loo_variance.log() + (targets - loo_mean) ** 2 / loo_variance
        return (-0.5 * per_point).sum()

    def cholesky_append(self, kernel_peek, peek_covar):
        '''Compute the blocks that extend the stored Cholesky factor by additional points.

        Parameters
        ----------
        kernel_peek : :py:obj:`torch.Tensor`
            Cross-covariance between the training points (rows) and the additional points (columns).
        peek_covar : :py:obj:`torch.Tensor`
            Covariance among the additional points.

        Returns
        -------
        L21 : obj:`torch.Tensor`
            Lower-left block of the extended Cholesky factor.
        L22 : obj:`torch.Tensor`
            Lower-right block of the extended Cholesky factor, i.e. the Cholesky factor of the Schur complement.

        Raises
        ------
        SingularGramError
            If the Schur complement is rank-deficient (absolute tolerance 1e-10) or its Cholesky decomposition fails.
        '''
        message = 'Updated Gram matrix is singular!'

        lower_left = torch.linalg.solve_triangular(self.cholesky, kernel_peek, upper=False).transpose(-2, -1)
        schur = peek_covar - lower_left @ lower_left.transpose(-2, -1)

        rank = torch.linalg.matrix_rank(schur, atol=1e-10)
        if (rank < schur.shape[-1]).any():
            raise SingularGramError(message)

        try:
            lower_right = torch.linalg.cholesky(schur, upper=False)
        except RuntimeError as error:
            raise SingularGramError(message) from error

        return lower_left, lower_right
